
Intermediate checkpointing for sequential calibration #1152

Open
sugunav14 wants to merge 7 commits into main from svelury/seq-calib-save-restore

Conversation

@sugunav14
Contributor

@sugunav14 sugunav14 commented Mar 31, 2026

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Sequential quantization calibration supports checkpoint save/resume with configurable directory and interval.
    • HuggingFace models can persist sequential calibration checkpoints via their standard save flow.
  • Bug Fixes / Behavior

    • Checkpoint settings warn and are ignored when sequential calibration is disabled.
    • Checkpoint progress is cleaned up after calibration completion or resume.
  • Tests

    • Added tests for checkpoint save, resume, registry, warnings, and cleanup behavior.
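The "configurable directory and interval" above implies a simple "save every N layers" rule. A minimal sketch of that decision logic, assuming the function name from the walkthrough and the `(layer_idx + 1) % interval` check quoted later in the review, plus the positive-interval guard requested there; this is an illustration, not the verbatim modelopt implementation:

```python
def should_save_seq_calib_checkpoint(layer_idx, total_layers, checkpoint_dir, interval):
    # Checkpointing is off unless a directory was configured.
    if checkpoint_dir is None:
        return False
    # Guard against zero/negative intervals (per the review comment on checkpoint.py).
    if interval is None or interval <= 0:
        raise ValueError(f"checkpoint interval must be a positive int, got {interval!r}")
    completed = layer_idx + 1
    # Save on every interval boundary; assume the final layer needs no rolling checkpoint.
    return completed % interval == 0 and completed < total_layers

assert should_save_seq_calib_checkpoint(3, 32, "/tmp/ckpt", 4)      # layers 0..3 done
assert not should_save_seq_calib_checkpoint(4, 32, "/tmp/ckpt", 4)  # mid-interval
assert not should_save_seq_calib_checkpoint(3, 32, None, 4)         # disabled
```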

Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
@copy-pr-bot

copy-pr-bot bot commented Mar 31, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Walkthrough

Adds sequential calibration checkpointing: new config options, checkpoint registry and saver selection, model metadata persistence for resume, collector state-machine changes to support resuming and warm-up, HuggingFace saver plugin, and unit tests validating save/resume behavior.

Changes

  • Configuration (modelopt/torch/quantization/config.py): Added `sequential_checkpoint_dir` and `sequential_checkpoint_interval` config options for sequential-calibration checkpointing.
  • Checkpoint Utilities (modelopt/torch/quantization/utils/checkpoint.py): New module introducing SEQ_CALIB_PROGRESS_ATTR, a saver registry (register_seq_calib_checkpoint_saver), saver selection (get_checkpoint_saver), resume detection (detect_sequential_resume_layer), checkpoint decision logic (should_save_seq_calib_checkpoint), and checkpoint persistence (save_sequential_checkpoint).
  • Metadata Propagation (modelopt/torch/quantization/conversion.py): Persists and restores seq_calib_progress between the model attribute (_seq_calib_progress) and checkpoint metadata during save/restore.
  • Calibration Mode Handling (modelopt/torch/quantization/mode.py): The calibration wrapper extracts sequential_checkpoint_dir and sequential_checkpoint_interval, warns if they are provided without sequential mode, and forwards them into sequential_calibrate when sequential mode is enabled.
  • Sequential Calibration Core (modelopt/torch/quantization/model_calib.py): The sequential_calibrate signature now accepts checkpoint_dir and checkpoint_interval; it integrates resume detection, prepares the collector for resume, iterates from the resume layer, conditionally saves rolling checkpoints, and removes the progress attribute on cleanup.
  • Activation Collector (modelopt/torch/quantization/utils/activation_collector.py): Constrained layer mode to Literal values; centralized _set_layer_mode; _patch_all_layers accepts layer_output_metas; added prepare_for_resume (warm-up + validation) and get_layer_output_metas.
  • Plugin: HuggingFace Saver (modelopt/torch/quantization/plugins/huggingface.py): Added _save_hf_checkpoint and registered it with the checkpoint saver registry for supported HF models.
  • Tests (tests/unit/torch/quantization/test_sequential_calibrate.py): New comprehensive tests for the saver registry, checkpoint triggering, resume detection/validation, metadata serialization, cleanup, and warning semantics.
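The saver registry described for checkpoint.py and the HuggingFace plugin can be illustrated with a self-contained sketch. The registry and lookup names follow the walkthrough, but the predicate/saver signatures and the save_pretrained-based example are assumptions, not the actual modelopt code:

```python
# Registered (predicate, save_fn) pairs; the first matching predicate wins.
_CHECKPOINT_SAVERS = []

def register_seq_calib_checkpoint_saver(predicate, save_fn):
    # Associate a save function with a predicate that recognizes supported models.
    _CHECKPOINT_SAVERS.append((predicate, save_fn))

def get_checkpoint_saver(model):
    # Return the first registered saver whose predicate matches the model.
    for predicate, save_fn in _CHECKPOINT_SAVERS:
        if predicate(model):
            return save_fn
    raise RuntimeError(f"No checkpoint saver registered for {type(model).__name__}")

# E.g. a HuggingFace-style plugin might register a saver for models that expose
# save_pretrained (hypothetical; the real plugin's check may differ):
register_seq_calib_checkpoint_saver(
    lambda m: hasattr(m, "save_pretrained"),
    lambda m, ckpt_dir: m.save_pretrained(ckpt_dir),
)
```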

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant SequentialCalib as sequential_calibrate
    participant Collector as LayerActivationCollector
    participant CheckpointUtils
    participant Saver as ModelCheckpointSaver
    participant Model

    User->>SequentialCalib: start(calib_func, checkpoint_dir, interval)
    SequentialCalib->>CheckpointUtils: detect_sequential_resume_layer(Model, num_layers)
    CheckpointUtils->>Model: getattr(_seq_calib_progress)
    Model-->>CheckpointUtils: progress or None
    CheckpointUtils-->>SequentialCalib: resume_idx, metadata

    alt resume available
        SequentialCalib->>Collector: prepare_for_resume(resume_idx, forward_loop)
        Collector->>Collector: _run_warmup_capture / set modes
    end

    loop per-layer from resume_idx
        SequentialCalib->>Collector: _set_layer_states(layer_idx)
        Collector->>Model: forward(pass)
        SequentialCalib->>CheckpointUtils: should_save_seq_calib_checkpoint(layer_idx,...)
        alt should save
            SequentialCalib->>Collector: get_layer_output_metas(up_to_idx)
            Collector-->>SequentialCalib: metas
            SequentialCalib->>CheckpointUtils: save_sequential_checkpoint(Model, layer_idx, total, checkpoint_dir, metas)
            CheckpointUtils->>CheckpointUtils: get_checkpoint_saver(Model)
            CheckpointUtils->>Model: setattr(_seq_calib_progress, payload)
            CheckpointUtils->>Saver: save_fn(Model, checkpoint_dir)
            Saver->>Model: persist checkpoint
        end
    end

    SequentialCalib->>Model: delattr(_seq_calib_progress) (cleanup)
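The getattr/setattr flow in the diagram above amounts to a small piece of progress bookkeeping on the model. A self-contained sketch (attribute and key names follow the walkthrough; the real implementation persists more state, such as layer_output_metas):

```python
SEQ_CALIB_PROGRESS_ATTR = "_seq_calib_progress"

class DummyModel:
    pass

def save_progress(model, completed_layer_idx, total_layers):
    # Record which layer finished last so a later run can resume after it.
    setattr(model, SEQ_CALIB_PROGRESS_ATTR, {
        "completed_layer_idx": completed_layer_idx,
        "total_layers": total_layers,
    })

def detect_resume_layer(model, num_layers):
    # Mirrors detect_sequential_resume_layer: no saved progress means start at layer 0.
    progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
    if progress is None:
        return 0
    if progress["total_layers"] != num_layers:
        raise ValueError("Checkpoint layer count does not match the model; cannot resume.")
    return progress["completed_layer_idx"] + 1

m = DummyModel()
assert detect_resume_layer(m, 8) == 0      # fresh run
save_progress(m, completed_layer_idx=3, total_layers=8)
assert detect_resume_layer(m, 8) == 4      # resume after layer 3
delattr(m, SEQ_CALIB_PROGRESS_ATTR)        # cleanup step from the diagram
```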

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 74.19%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title "Intermediate checkpointing for sequential calibration" directly and clearly summarizes the main change: adding checkpoint save/resume functionality to sequential calibration.
  • Security Anti-Patterns: ✅ Passed. The pull request does not introduce security anti-patterns as defined in SECURITY.md. The only torch.load(..., weights_only=False) call is in test code, which is explicitly exempted from the security coding practices.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions bot commented Apr 1, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1152/

Built to branch gh-pages at 2026-04-01 22:18 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 91.94631% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 54.64%. Comparing base (ada1e26) to head (9a017ef).
⚠️ Report is 7 commits behind head on main.

Files with missing lines (patch coverage / missing lines):
  • ...t/torch/quantization/utils/activation_collector.py: 86.48%, 10 missing ⚠️
  • modelopt/torch/quantization/conversion.py: 83.33%, 1 missing ⚠️
  • modelopt/torch/quantization/plugins/huggingface.py: 75.00%, 1 missing ⚠️

❗ There is a different number of reports uploaded between BASE (ada1e26) and HEAD (9a017ef). Click for more details.

HEAD has 2 fewer uploads than BASE: BASE (ada1e26) had 2, HEAD (9a017ef) has 0.
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1152       +/-   ##
===========================================
- Coverage   70.18%   54.64%   -15.54%     
===========================================
  Files         230      349      +119     
  Lines       26080    39895    +13815     
===========================================
+ Hits        18304    21802     +3498     
- Misses       7776    18093    +10317     
Flag Coverage Δ
unit 54.64% <91.94%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@sugunav14 sugunav14 marked this pull request as ready for review April 1, 2026 21:24
@sugunav14 sugunav14 requested a review from a team as a code owner April 1, 2026 21:24
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/conversion.py`:
- Around line 143-145: The seq_calib_progress metadata is restored after the
fast-path return so the resume marker is lost; modify the fast-path in
convert_quantized_model (or the function containing the quantizer_state
fast-path) to set setattr(model, SEQ_CALIB_PROGRESS_ATTR,
metadata["seq_calib_progress"]) (guarded by "seq_calib_progress" in metadata)
before any early return that returns quantizer_state via extra_state, ensuring
the metadata restoration happens regardless of taking the fast-path.

In `@modelopt/torch/quantization/utils/activation_collector.py`:
- Around line 334-345: get_layer_output_metas currently returns the stored
output_meta objects verbatim (from _extract_output_meta) which may include
torch.device info and cause cross-device errors on resume; update
get_layer_output_metas (and the complementary loading path in _seq_calib) to
produce device-agnostic serializable metas: either strip/convert any
torch.device fields to a neutral representation (e.g. store device as string
like "cpu") or move tensors to a canonical device (cpu) before returning, and
ensure the loader remaps that neutral/device-string back to the current runtime
device when rebuilding state; look for references to _decoder_layers,
_LAYER_ATTR, state.output_meta, and _seq_calib to implement symmetric save (in
get_layer_output_metas) and load remapping so forward passes never receive
tensors bound to a stale device.

In `@modelopt/torch/quantization/utils/checkpoint.py`:
- Around line 104-113: The function should_save_seq_calib_checkpoint currently
does a modulo with checkpoint_interval which raises for zero and misbehaves for
negatives; add an upfront guard in should_save_seq_calib_checkpoint to reject
non-positive intervals by checking if checkpoint_interval is not None and
checkpoint_interval > 0 and raise a ValueError (with a clear message referencing
checkpoint_interval) before performing the modulo, so the later logic that uses
(layer_idx + 1) % checkpoint_interval and the other checks can remain unchanged.
- Around line 83-101: The persisted progress payload (stored under
SEQ_CALIB_PROGRESS_ATTR and loaded into progress) must be validated before use:
ensure progress is a dict and contains integer keys "completed_layer_idx" and
"total_layers", then verify completed_layer_idx is within [-1, num_layers - 1]
and that total_layers equals num_layers; if any check fails, raise a clear
ValueError (or return 0, None) instead of proceeding. Update the logic around
the progress variable, completed_layer and saved_total to validate types and
ranges before computing resume_from, printing via print_rank_0, or returning
layer_output_metas.

In `@tests/unit/torch/quantization/test_sequential_calibrate.py`:
- Around line 947-964: Replace the duplicated logic in
test_update_quantize_metadata_includes_progress with a real call to
update_quantize_metadata: set the SEQ_CALIB_PROGRESS_ATTR on the model as you
do, create an empty metadata dict (or config expected by
update_quantize_metadata), call update_quantize_metadata(model, metadata) (or
the correct signature of update_quantize_metadata) so it picks up
SEQ_CALIB_PROGRESS_ATTR, then assert metadata["seq_calib_progress"] equals the
progress value and finally delattr the SEQ_CALIB_PROGRESS_ATTR; reference the
test name test_update_quantize_metadata_includes_progress, the function
update_quantize_metadata, and the attribute SEQ_CALIB_PROGRESS_ATTR when
locating the code to change.
- Around line 984-987: Add an inline comment next to the torch.load call
explaining why using weights_only=False is safe: note that the buffer `buf` is
produced locally by `torch.save(progress, buf)` (not from external input) so
deserializing with `weights_only=False` does not violate the security guideline;
annotate the `loaded = torch.load(buf, weights_only=False)` line with this
justification referencing `buf`, `progress`, and `torch.load`.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7693e81d-61c3-4338-9c0f-67d1b730e51b

📥 Commits

Reviewing files that changed from the base of the PR and between ada1e26 and b9035fb.

📒 Files selected for processing (8)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/conversion.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/activation_collector.py
  • modelopt/torch/quantization/utils/checkpoint.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py

Comment on lines +83 to +101
    progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
    if progress is None:
        return 0, None

    completed_layer = progress["completed_layer_idx"]
    saved_total = progress["total_layers"]

    if saved_total != num_layers:
        raise ValueError(
            f"Checkpoint was saved with {saved_total} layers but model has "
            f"{num_layers} layers. Cannot resume."
        )

    resume_from = completed_layer + 1
    print_rank_0(
        f"Resuming sequential calibration from layer {resume_from} "
        f"(layers 0..{completed_layer} already calibrated)"
    )
    return resume_from, progress.get("layer_output_metas", {})
Contributor


⚠️ Potential issue | 🟠 Major

Validate the persisted progress payload before using it.

This metadata comes back from disk, so malformed values currently turn into KeyError or an impossible resume point later in the flow. Please validate the schema and enforce completed_layer_idx within [-1, num_layers - 1] before logging or returning it.

🛠️ Suggested guard
 def detect_sequential_resume_layer(model: nn.Module, num_layers: int) -> tuple[int, dict | None]:
     """Read checkpoint progress from the model and return ``(resume_layer_idx, layer_output_metas)``.
@@
     progress = getattr(model, SEQ_CALIB_PROGRESS_ATTR, None)
     if progress is None:
         return 0, None
-
-    completed_layer = progress["completed_layer_idx"]
-    saved_total = progress["total_layers"]
+    if not isinstance(progress, dict):
+        raise ValueError("Malformed sequential calibration checkpoint metadata.")
+    try:
+        completed_layer = int(progress["completed_layer_idx"])
+        saved_total = int(progress["total_layers"])
+    except (KeyError, TypeError, ValueError) as exc:
+        raise ValueError("Malformed sequential calibration checkpoint metadata.") from exc
 
     if saved_total != num_layers:
         raise ValueError(
             f"Checkpoint was saved with {saved_total} layers but model has "
             f"{num_layers} layers. Cannot resume."
         )
+    if completed_layer < -1 or completed_layer >= num_layers:
+        raise ValueError(
+            f"Checkpoint completed_layer_idx={completed_layer} is out of range "
+            f"for a model with {num_layers} layers."
+        )

Comment on lines +984 to +987
buf = io.BytesIO()
torch.save(progress, buf)
buf.seek(0)
loaded = torch.load(buf, weights_only=False)
Contributor


⚠️ Potential issue | 🔴 Critical



Add inline comment to document why weights_only=False is safe.

Line 987 requires an inline comment justifying weights_only=False per security guidelines. The buffer is locally generated via torch.save() and never comes from external input, which satisfies the exception criteria.

Suggested fix
         buf = io.BytesIO()
         torch.save(progress, buf)
         buf.seek(0)
-        loaded = torch.load(buf, weights_only=False)
+        # Safe here: `buf` is produced by `torch.save` in this test and never comes from user input.
+        loaded = torch.load(buf, weights_only=False)

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/checkpoint.py (1)

161-165: Consider using the constant for consistency.

Line 161 uses the literal _seq_calib_progress while line 83 uses getattr(model, SEQ_CALIB_PROGRESS_ATTR, ...). Using setattr with the constant would improve maintainability.

♻️ Suggested change
-    model._seq_calib_progress = {
+    setattr(model, SEQ_CALIB_PROGRESS_ATTR, {
         "completed_layer_idx": completed_layer_idx,
         "total_layers": total_layers,
         "layer_output_metas": layer_output_metas,
-    }
+    })
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/utils/activation_collector.py`:
- Around line 452-456: The code flips preceding layers to "skip" via
_set_layer_mode before calling _validate_skip_metas, which can leave the
collector partially modified on validation failure; change the logic to validate
the skip metas before flipping modes or, if you must flip first, catch
exceptions from _validate_skip_metas and restore all layers in preceding back to
"original" using _set_layer_mode to avoid a half-resumed state; reference the
surrounding calls _run_warmup_capture, _set_layer_mode(preceding), and
_validate_skip_metas to locate where to add the pre-validation check or the
exception handler+restore.
- Around line 435-443: prepare_for_resume() uses resume_layer_idx without
validating it; add an upfront range check after confirming self._patched to
ensure 0 <= resume_layer_idx <= total_layers (use the class' layer count, e.g.,
len(self.layers) or self.num_layers) and raise a ValueError with a clear message
if out of range. Keep the existing resume_layer_idx == 0 early-return behavior
but perform the validation before any state mutation (before assigning k or
computing preceding) so negative indices or overly large indices are rejected
immediately.

---

Nitpick comments:
In `@modelopt/torch/quantization/utils/checkpoint.py`:
- Around line 161-165: Replace the literal attribute assignment
model._seq_calib_progress with a setattr using the existing
SEQ_CALIB_PROGRESS_ATTR constant to match the getattr usage elsewhere;
specifically, set the attribute on model via setattr(model,
SEQ_CALIB_PROGRESS_ATTR, {...}) using the same keys (completed_layer_idx,
total_layers, layer_output_metas) so the code consistently references
SEQ_CALIB_PROGRESS_ATTR instead of the hard-coded "_seq_calib_progress".
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: dfffce23-8e9d-4e21-aaf3-1d51916f4c88

📥 Commits

Reviewing files that changed from the base of the PR and between 5d2ef1a and 9a017ef.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/conversion.py
  • modelopt/torch/quantization/utils/activation_collector.py
  • modelopt/torch/quantization/utils/checkpoint.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unit/torch/quantization/test_sequential_calibrate.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/conversion.py

Comment on lines +435 to +443
        if not self._patched:
            raise RuntimeError(
                "prepare_for_resume() requires _patch_all_layers() to be called first."
            )
        if resume_layer_idx == 0:
            return

        k = resume_layer_idx
        preceding = range(k - 1)
Contributor


⚠️ Potential issue | 🟠 Major

Validate resume_layer_idx before using it.

Negative values will target the last layers instead of being rejected, and values past the end will fail only after resume setup has already started mutating state. Since this value is checkpoint-derived, please range-check it up front.

💡 Suggested fix
         if not self._patched:
             raise RuntimeError(
                 "prepare_for_resume() requires _patch_all_layers() to be called first."
             )
+        assert self._decoder_layers is not None
+        num_layers = len(self._decoder_layers)
+        if not 0 <= resume_layer_idx < num_layers:
+            raise ValueError(
+                f"resume_layer_idx must be in [0, {num_layers - 1}], got {resume_layer_idx}."
+            )
         if resume_layer_idx == 0:
             return

Comment on lines +452 to +456
        self._run_warmup_capture(k - 1, forward_loop)

        for i in preceding:
            self._set_layer_mode(i, "skip")
        self._validate_skip_metas(preceding)
Contributor


⚠️ Potential issue | 🟡 Minor

Rollback the mode flip if resume validation fails.

_validate_skip_metas() runs after the earlier layers have already been moved to skip, so a missing meta leaves the collector in a half-resumed state on the exception path. Please validate before the flip, or restore the touched layers to original when validation fails.

💡 Suggested fix
-        self._run_warmup_capture(k - 1, forward_loop)
-
-        for i in preceding:
-            self._set_layer_mode(i, "skip")
-        self._validate_skip_metas(preceding)
+        try:
+            self._run_warmup_capture(k - 1, forward_loop)
+            self._validate_skip_metas(preceding)
+        except Exception:
+            for i in preceding:
+                self._set_layer_mode(i, "original")
+            state = self._decoder_layers[k - 1]._seq_calib
+            state.mode = "original"
+            state.collected_inputs = []
+            raise
+
+        for i in preceding:
+            self._set_layer_mode(i, "skip")
